Today we are going to implement a GAN. Unlike the AutoEncoder models from earlier days, a GAN is built mostly from convolutional layers rather than the fully-connected layers we used before. Because GAN training is notoriously unstable, experienced practitioners recommend paying special attention to the choice of layers and activation functions. Today's implementation is based on https://colab.research.google.com/drive/1hNMJ1C3ARYud-6UDqKGYx12cGZ9ULDZ-#scrollTo=YgH_d6fNVuEw
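All of the snippets below assume TensorFlow 2.x with its built-in Keras API, so these imports are needed before the classes are defined:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
```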
The Generator is the image producer: it uses layers.Conv2DTranspose to upsample latent features back into a full image (the shape trace right after the class shows how each layer enlarges the feature map).
```python
class Generator(keras.Model):
    def __init__(self):
        super(Generator, self).__init__()
        # Project the latent vector z: [b, 100] -> [b, 3*3*512], then upsample to [b, 64, 64, 3]
        self.fc_layer_1 = layers.Dense(3 * 3 * 512)
        self.conv_1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
        self.bn_1 = layers.BatchNormalization()
        self.conv_2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
        self.bn_2 = layers.BatchNormalization()
        self.conv_3 = layers.Conv2DTranspose(3, 4, 3, 'valid')

    def call(self, inputs, training=None):
        x = self.fc_layer_1(inputs)                         # [b, 100] -> [b, 3*3*512]
        x = tf.reshape(x, [-1, 3, 3, 512])                  # [b, 3, 3, 512]
        x = tf.nn.relu(x)
        x = self.bn_1(self.conv_1(x), training=training)    # [b, 9, 9, 256]
        x = self.bn_2(self.conv_2(x), training=training)    # [b, 21, 21, 128]
        x = self.conv_3(x)                                  # [b, 64, 64, 3]
        x = tf.tanh(x)                                      # pixel values in [-1, 1]
        return x
```
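With padding='valid', a Conv2DTranspose layer produces (in - 1) * stride + kernel outputs per spatial dimension, so the feature map grows 3 → 9 → 21 → 64. A quick standalone sketch to trace the shapes:

```python
import tensorflow as tf
from tensorflow.keras import layers

# Trace the Generator's upsampling path with random features.
x = tf.random.normal([1, 3, 3, 512])
x = layers.Conv2DTranspose(256, 3, 3, 'valid')(x)  # (3-1)*3 + 3 = 9
print(x.shape)                                     # (1, 9, 9, 256)
x = layers.Conv2DTranspose(128, 5, 2, 'valid')(x)  # (9-1)*2 + 5 = 21
print(x.shape)                                     # (1, 21, 21, 128)
x = layers.Conv2DTranspose(3, 4, 3, 'valid')(x)    # (21-1)*3 + 4 = 64
print(x.shape)                                     # (1, 64, 64, 3)
```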
The Discriminator is an image classifier: it judges whether an image produced by the Generator looks real or fake (a matching shape trace follows the class below).
```python
class Discriminator(keras.Model):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Downsample [b, 64, 64, 3] step by step to a single real/fake logit
        self.conv_1 = layers.Conv2D(64, 5, 3, 'valid')
        self.conv_2 = layers.Conv2D(128, 5, 3, 'valid')
        self.bn_1 = layers.BatchNormalization()
        self.conv_3 = layers.Conv2D(256, 5, 3, 'valid')
        self.bn_2 = layers.BatchNormalization()
        self.flatten = layers.Flatten()
        self.fc_layer = layers.Dense(1)

    def call(self, inputs, training=None):
        x = tf.nn.leaky_relu(self.conv_1(inputs))                            # [b, 20, 20, 64]
        x = tf.nn.leaky_relu(self.bn_1(self.conv_2(x), training=training))   # [b, 6, 6, 128]
        x = tf.nn.leaky_relu(self.bn_2(self.conv_3(x), training=training))   # [b, 1, 1, 256]
        x = self.flatten(x)                                                  # [b, 256]
        x = self.fc_layer(x)                                                 # [b, 1] raw logit
        return x
```

![https://ithelp.ithome.com.tw/upload/images/20201009/20130246QRoobECSo6.png](https://ithelp.ithome.com.tw/upload/images/20201009/20130246QRoobECSo6.png)
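The Discriminator walks the resolution back down. With padding='valid', a Conv2D layer produces floor((in - kernel) / stride) + 1 outputs per spatial dimension, which is why the feature map ends at 1×1 before Flatten. A quick standalone trace, mirroring the Generator check above:

```python
import tensorflow as tf
from tensorflow.keras import layers

# Trace the Discriminator's downsampling path with a random image.
x = tf.random.normal([1, 64, 64, 3])
x = layers.Conv2D(64, 5, 3, 'valid')(x)   # floor((64-5)/3) + 1 = 20
print(x.shape)                            # (1, 20, 20, 64)
x = layers.Conv2D(128, 5, 3, 'valid')(x)  # floor((20-5)/3) + 1 = 6
print(x.shape)                            # (1, 6, 6, 128)
x = layers.Conv2D(256, 5, 3, 'valid')(x)  # floor((6-5)/3) + 1 = 1
print(x.shape)                            # (1, 1, 1, 256)
```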
Finally, a quick smoke test of both models with random tensors:

```python
g = Generator()
d = Discriminator()

x = tf.random.normal([1, 64, 64, 3])  # a stand-in "real" image
z = tf.random.normal([1, 100])        # a latent vector

fake_image = g(z)        # the Generator turns the latent vector into an image
print(fake_image.shape)  # (1, 64, 64, 3)

out = d(x)               # the Discriminator turns an image into a single logit
print(out.shape)         # (1, 1)
```
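One more detail worth checking: fc_layer has no activation, so `out` is a raw logit rather than a probability, and tf.sigmoid maps it into (0, 1) (during actual training this is usually handled by a from_logits=True cross-entropy instead). Reusing g, d and z from above, we can also confirm the two models plug together end to end:

```python
# `out` is a raw logit; a sigmoid turns it into a real/fake probability
print(tf.sigmoid(out))    # value in (0, 1); closer to 1 means "looks real"

# Latent vector -> fake image -> logit, all the way through
fake_logit = d(g(z))
print(fake_logit.shape)   # (1, 1)
```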
Today was just setting up a simple GAN model; the generated images will only be drawn tomorrow. The 30 days of posts are finally almost complete!!!